import os 
import matplotlib.pyplot as plt 
import torch 
import os 
import numpy as np
from PIL import Image
from torchvision import transforms

def list_visible(directory):
    """Return the sorted names in *directory* that do not start with '.'.

    Hidden entries such as macOS ``.DS_Store`` files would otherwise be
    treated as sample folders or handed to ``plt.imread``, which cannot
    decode them.
    """
    return sorted(name for name in os.listdir(directory) if not name.startswith('.'))


def combine_masks(masks):
    """Merge per-instance mask arrays into a single mask by pixel-wise maximum.

    Taking the maximum is the union of the masks. Unlike summing, it never
    produces values outside the range of the inputs when instances overlap,
    so the result stays a valid mask (``plt.imread`` returns PNGs as floats
    in [0, 1]).

    Args:
        masks: iterable of equally-shaped array-likes.

    Returns:
        The merged ``np.ndarray``, or ``None`` if *masks* is empty.
        The input arrays are not modified.
    """
    combined = None
    for mask in masks:
        mask = np.asarray(mask)
        # Copy the first mask so later merges never write into a caller's array.
        combined = mask.copy() if combined is None else np.maximum(combined, mask)
    return combined


def merge_sample_masks(sample_dir, sample_name):
    """Merge ``<sample_dir>/masks/*`` into ``<sample_dir>/save_masks/<sample_name>.png``.

    Returns:
        True if a merged mask image was written; False if the sample has no
        ``masks`` folder or that folder holds no (non-hidden) mask files.
    """
    mask_dir = os.path.join(sample_dir, "masks")
    if not os.path.isdir(mask_dir):
        return False
    merged = combine_masks(
        plt.imread(os.path.join(mask_dir, name)) for name in list_visible(mask_dir)
    )
    if merged is None:
        return False
    save_dir = os.path.join(sample_dir, "save_masks")
    os.makedirs(save_dir, exist_ok=True)
    # NOTE: for 2-D data plt.imsave applies its default colormap; pass
    # cmap="gray" if a grayscale mask image is wanted instead.
    plt.imsave(os.path.join(save_dir, sample_name + ".png"), merged)
    return True


def main(root_dir="./DSB细胞分割训练数据/"):
    """Merge the instance masks of every sample folder under *root_dir*.

    Each visible sub-folder of *root_dir* is treated as one sample. Its
    ``masks/`` images are merged and saved as
    ``<sample>/save_masks/<sample>.png``.
    """
    for each_dir in list_visible(root_dir):  # find all image folders
        sample_dir = os.path.join(root_dir, each_dir)
        if not os.path.isdir(sample_dir):
            continue  # stray files next to the sample folders
        if merge_sample_masks(sample_dir, each_dir):
            print("success~")
        else:
            print(f"skipped {each_dir}: no masks found")


if __name__ == "__main__":
    main()

    
           



